-
Notifications
You must be signed in to change notification settings - Fork 90
fix: make numpy_backend.tile() and jax_backend.tile() consistent #2587
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@ligerlac Thanks for the PR. Today I have been clawing myself out of travel related time dependent TODOs, but I can review this on Thursday (2025-05-22). I haven't looked/thought about this yet, but I assume that this isn't something unique to |
It's more of a patch. You are right, the problem is not unqiue to
last line fails with
We could also patch that in the jax backend. But I guess a more elegant solution would be to make sure that each backend is only receiving arguments of the correct type by calling |
Codecov Report❌ Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #2587 +/- ##
==========================================
- Coverage 98.23% 98.18% -0.05%
==========================================
Files 65 65
Lines 4193 4195 +2
Branches 591 592 +1
==========================================
Hits 4119 4119
- Misses 45 46 +1
- Partials 29 30 +1
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
This fixes a bug in the jax_backend.tile() method. Consider the following minimal example:
The last line fails with
TypeError: tile requires ndarray or scalar arguments, got <class 'list'> at position 0.
. However, it works fine when using the numpy backend. The problem stems from differences betweennp.tile
andjnp.tile
:Unlike
jnp.tile
,np.tile
implicitly converts the input to the correct type.This PR ensures
tensor_in
is ajnp.array
to make the behaviour ofnumpy_backend.tile()
andjax_backend.tile()
consistent.